#!/usr/bin/python
"""A script to train a regularized SVD model
on a given numpy record array. The data type is assumed to 
be rsvd.rating_t. A struct containing an uint16, an uint32 and an uint8.
"""
import sys
import getopt
import numpy as np

import rsvd

# Module metadata is taken directly from the `rsvd` package attributes of the
# same names rather than being defined separately in this script.
__version__ = rsvd.__version__
__author__  = rsvd.__author__
__license__ = rsvd.__license__


class Usage(Exception):
    """Signal a command line usage error.

    The text describing the problem is kept on the `msg` attribute, which
    is what `main` prints before showing the help text.
    """

    def __init__(self, msg):
        # Only the message is recorded; nothing else is stored on the instance.
        self.msg = msg

def usage():
    """Print the command line help text to standard output.

    The defaults quoted in the text mirror the initial values that `main`
    assigns before parsing the options (10 factors, learn rate 0.001,
    regularization 0.011, 100 epochs, minimum improvement 0.000001).
    """
    # print(<single string>) behaves identically as a Python 2 print
    # statement and as a Python 3 function call.
    print("""Usage: rsvd_train [options] training_array num_movies num_users output_dir
    
Trains a regularized SVD solver on the given training data and stores the 
trained model in `output_dir`. The training data is assumed to be 
a serialized numpy record array. `num_movies` and `num_users` are the number 
of movies and users, resp., in the data set.
See <http://code.google.com/p/pyrsvd/> for further information. 

Options: 
  -h, --help\tPrint this help
  -f <int>\tThe number of latent factors to compute [default: 10].
  -l <float>\tLearn rate [default: 0.001]
  -r <float>\tRegularization parameter [default: 0.011]
  --probe <file>\tEstimate error on probe set. 
\t\t<file> contains the probe set in a numpy record array.
\t\tIf defined, early stopping is turned on
  --maxepochs <int>\tThe max number of epochs to perform [default: 100]
  --randomize\tShuffle the training data every 10th epoch.
  --minimprovement <float>\tThe min improvement in RMSE on the probeset to trigger early stopping [default: 0.000001].

For bug reporting, please mail to:
<peter.prettenhofer@gmail.com>
""")


def _parse_number(convert, name, text):
    """Convert `text` with `convert` (int or float); raise Usage if malformed."""
    try:
        return convert(text)
    except ValueError:
        raise Usage("invalid value for %s: %r" % (name, text))


def _load_ratings(path, label):
    """Read a record array of dtype rsvd.rating_t from `path`, reporting progress."""
    sys.stdout.write("Loading %s data...\t" % label)
    sys.stdout.flush()
    # Binary mode avoids newline translation of the raw records, and the
    # with-block guarantees the handle is closed even if reading fails.
    with open(path, 'rb') as handle:
        ratings = np.fromfile(handle, dtype=rsvd.rating_t)
    sys.stdout.write("done.\n")
    return ratings


def main(argv=None):
    """Parse the command line, train an RSVD model and save it.

    Parameters
    ----------
    argv : list of str, optional
        Full argument vector including the program name at index 0.
        Defaults to `sys.argv`.

    Returns
    -------
    int
        0 on success, 1 if loading, training or saving failed, and 2 if
        the command line was invalid (the help text is printed as well).

    Raises
    ------
    SystemExit
        When `-h`/`--help` is given, after the help text has been printed.
    """
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "hl:r:f:",
                                       ["help", "probe=", "maxepochs=",
                                        "minimprovement=", "randomize"])
        except getopt.error as msg:
            raise Usage(msg)

        # Defaults; each may be overridden by the options below.
        lr = 0.001
        reg = 0.011
        probeFile = None
        max_epochs = 100
        randomize = False
        factors = 10
        min_improvement = 0.000001
        for o, a in opts:
            # Compare with ==: `o in ("--probe")` would be a substring test
            # on a plain string, not membership in a tuple.
            if o in ("-h", "--help"):
                usage()
                sys.exit()
            elif o == "--probe":
                probeFile = a
            elif o == "-l":
                lr = _parse_number(float, o, a)
            elif o == "-r":
                reg = _parse_number(float, o, a)
            elif o == "-f":
                factors = _parse_number(int, o, a)
            elif o == "--maxepochs":
                max_epochs = _parse_number(int, o, a)
            elif o == "--randomize":
                randomize = True
            elif o == "--minimprovement":
                min_improvement = _parse_number(float, o, a)
            else:
                raise Usage("unhandled option")

        if len(args) < 4:
            raise Usage("missing arguments")
        # Malformed counts are command line mistakes, so report them as
        # usage errors instead of generic runtime failures.
        nmovies = _parse_number(int, "num_movies", args[1])
        nusers = _parse_number(int, "num_users", args[2])
        output = args[3]

        try:
            ratingsArray = _load_ratings(args[0], "training")
            probeArray = None
            if probeFile:
                probeArray = _load_ratings(probeFile, "probe")
            model = rsvd.RSVD.train(factors, ratingsArray, (nmovies, nusers),
                                    probeArray=probeArray,
                                    maxEpochs=max_epochs,
                                    minImprovement=min_improvement,
                                    learnRate=lr,
                                    regularization=reg,
                                    randomize=randomize)
            model.save(output)
        except Exception as e:
            # Top-level boundary: report the failure and signal it through a
            # non-zero exit status rather than silently exiting with 0.
            sys.stderr.write("Error:  %s\n" % (e,))
            return 1
        return 0

    except Usage as err:
        sys.stderr.write("Error:  %s\n" % (err.msg,))
        usage()
        return 2

# Run main() only when this file is executed as a script; the value returned
# by main() is handed to sys.exit() and so becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
